#! /usr/bin/env python3

import os
import shutil
import subprocess
import sys
import tarfile
import tempfile

from urllib.request import urlretrieve

from caffe2.python.models.download import (
    deleteDirectory,
    downloadFromURLToFile,
    getURLFromName,
)


class SomeClass:
    """Locate and fetch the Caffe2 and ONNX copies of a named model.

    Caffe2 models are stored under ``~/.caffe2/models/<model>`` and ONNX
    models under ``~/.onnx/models/<model>``.
    """

    # largely copied from
    # https://github.com/onnx/onnx-caffe2/blob/master/tests/caffe2_ref_test.py
    def _download(self, model):
        """Download the Caffe2 files of *model* into its Caffe2 model directory.

        Fetches ``predict_net.pb``, ``init_net.pb`` and ``value_info.json``.
        The target directory must not exist yet.  If any download fails, the
        directory is deleted again and the process exits with status 1.
        """
        model_dir = self._caffe2_model_dir(model)
        assert not os.path.exists(model_dir)
        os.makedirs(model_dir)
        for f in ["predict_net.pb", "init_net.pb", "value_info.json"]:
            url = getURLFromName(model, f)
            dest = os.path.join(model_dir, f)
            try:
                try:
                    downloadFromURLToFile(url, dest, show_progress=False)
                except TypeError:
                    # show_progress not supported prior to
                    # Caffe2 78c014e752a374d905ecfb465d44fa16e02a28f1
                    # (Sep 17, 2017)
                    downloadFromURLToFile(url, dest)
            except Exception as e:
                print(f"Abort: {e}")
                print("Cleaning up...")
                # Do not leave a partially populated directory behind.
                deleteDirectory(model_dir)
                sys.exit(1)

    def _caffe2_model_dir(self, model):
        """Return the path ``~/.caffe2/models/<model>`` (with ``~`` expanded)."""
        caffe2_home = os.path.expanduser("~/.caffe2")
        models_dir = os.path.join(caffe2_home, "models")
        return os.path.join(models_dir, model)

    def _onnx_model_dir(self, model):
        """Return ``(model_dir, parent_dir)`` for the ONNX copy of *model*.

        ``model_dir`` is ``~/.onnx/models/<model>`` (with ``~`` expanded) and
        ``parent_dir`` is the directory that contains it.
        """
        onnx_home = os.path.expanduser("~/.onnx")
        models_dir = os.path.join(onnx_home, "models")
        model_dir = os.path.join(models_dir, model)
        return model_dir, os.path.dirname(model_dir)

    # largely copied from
    # https://github.com/onnx/onnx/blob/master/onnx/backend/test/runner/__init__.py
    def _prepare_model_data(self, model):
        """Download and unpack the ONNX tarball of *model*.

        Does nothing when the ONNX model directory already exists.  Otherwise
        the archive ``<model>.tar.gz`` is downloaded to a temporary file and
        extracted into the parent of the model directory.

        Any exception raised while downloading or extracting is reported and
        re-raised.  In that case the freshly created model directory is
        removed again, so that a later run does not mistake the leftover
        directory for a successfully prepared model.
        """
        model_dir, models_dir = self._onnx_model_dir(model)
        if os.path.exists(model_dir):
            return
        os.makedirs(model_dir)
        url = f"https://s3.amazonaws.com/download.onnx/models/{model}.tar.gz"

        # On Windows, NamedTemporaryFile cannot be opened for a
        # second time
        download_file = tempfile.NamedTemporaryFile(delete=False)
        prepared = False
        try:
            download_file.close()
            print(f"Start downloading model {model} from {url}")
            urlretrieve(url, download_file.name)
            print("Done")
            with tarfile.open(download_file.name) as t:
                t.extractall(models_dir)
            prepared = True
        except Exception as e:
            print(f"Failed to prepare data for model {model}: {e}")
            raise
        finally:
            if not prepared:
                # Covers ordinary failures as well as interrupts (Ctrl-C):
                # an empty or half-extracted directory would otherwise make
                # every later existence check treat the model as present.
                shutil.rmtree(model_dir, ignore_errors=True)
            os.remove(download_file.name)


# Names of the models this script operates on.  Each name is used as the
# directory name under ~/.caffe2/models and ~/.onnx/models, and as the base
# name of the "<model>.tar.gz" archive that is generated and uploaded.
models = [
    "bvlc_alexnet",
    "densenet121",
    "inception_v1",
    "inception_v2",
    "resnet50",
    # TODO currently onnx can't translate squeezenet :(
    # 'squeezenet',
    "vgg16",
    # TODO currently vgg19 doesn't work in the CI environment,
    # possibly due to OOM
    # 'vgg19'
]


def download_models():
    """Fetch the Caffe2 and ONNX files of every listed model that is missing.

    A model is considered present when its directory already exists; the
    Caffe2 and the ONNX directory are checked independently.
    """
    helper = SomeClass()
    for name in models:
        print("update-caffe2-models.py:  downloading", name)
        caffe2_dir = helper._caffe2_model_dir(name)
        # Only the model directory itself is needed here, not its parent.
        onnx_dir, _ = helper._onnx_model_dir(name)

        caffe2_missing = not os.path.exists(caffe2_dir)
        if caffe2_missing:
            helper._download(name)

        onnx_missing = not os.path.exists(onnx_dir)
        if onnx_missing:
            helper._prepare_model_data(name)


def generate_models():
    """Convert every listed Caffe2 model to ONNX and pack it as a tarball.

    For each model, ``convert-caffe2-to-onnx`` is run on the files in the
    Caffe2 model directory, writing ``model.pb`` into the ONNX model
    directory.  That directory is then archived as ``<model>.tar.gz`` next
    to it.  A failing external command raises ``CalledProcessError``.
    """
    helper = SomeClass()
    for name in models:
        print("update-caffe2-models.py:  generating", name)
        caffe2_dir = helper._caffe2_model_dir(name)
        onnx_dir, onnx_parent = helper._onnx_model_dir(name)
        subprocess.check_call(["echo", name])

        # The converter receives the contents of value_info.json, not its path.
        value_info_path = os.path.join(caffe2_dir, "value_info.json")
        with open(value_info_path, "r") as info_file:
            value_info_text = info_file.read()

        convert_cmd = ["convert-caffe2-to-onnx"]
        convert_cmd += ["--caffe2-net-name", name]
        convert_cmd += ["--caffe2-init-net", os.path.join(caffe2_dir, "init_net.pb")]
        convert_cmd += ["--value-info", value_info_text]
        convert_cmd += ["-o", os.path.join(onnx_dir, "model.pb")]
        convert_cmd.append(os.path.join(caffe2_dir, "predict_net.pb"))
        subprocess.check_call(convert_cmd)

        # Run tar from the parent directory so the archive holds "<model>/...".
        archive_name = name + ".tar.gz"
        subprocess.check_call(
            ["tar", "-czf", archive_name, name], cwd=onnx_parent
        )


def upload_models():
    """Upload the ``<model>.tar.gz`` archive of every listed model to S3.

    Each archive is copied with ``aws s3 cp`` from the directory that
    contains the ONNX model directories, using the ``public-read`` ACL.
    A failing upload raises ``CalledProcessError``.
    """
    helper = SomeClass()
    for name in models:
        print("update-caffe2-models.py:  uploading", name)
        # Only the parent directory is needed: the archive sits next to
        # the model directory, not inside it.
        _, archive_dir = helper._onnx_model_dir(name)
        archive_name = name + ".tar.gz"
        destination = f"s3://download.onnx/models/{name}.tar.gz"
        upload_cmd = [
            "aws",
            "s3",
            "cp",
            archive_name,
            destination,
            "--acl",
            "public-read",
        ]
        subprocess.check_call(upload_cmd, cwd=archive_dir)


def cleanup():
    """Delete the ``<model>.tar.gz`` archive of every listed model.

    The archive is expected next to the ONNX model directory; a missing
    archive makes ``os.remove`` raise.
    """
    helper = SomeClass()
    for name in models:
        model_dir, _ = helper._onnx_model_dir(name)
        archive_dir = os.path.dirname(model_dir)
        archive_path = os.path.join(archive_dir, name + ".tar.gz")
        os.remove(archive_path)


if __name__ == "__main__":
    # Script entry point: run the sub-command named by the first argument.
    # Each supported sub-command maps to the function that implements it.
    commands = {
        "download": download_models,
        "generate": generate_models,
        "upload": upload_models,
        "cleanup": cleanup,
    }
    # Reject a missing or unknown command up front instead of failing with
    # an IndexError or silently doing nothing.
    if len(sys.argv) < 2 or sys.argv[1] not in commands:
        print(
            "update-caffe2-models.py:  usage: update-caffe2-models.py "
            "{download|generate|upload|cleanup}"
        )
        sys.exit(1)
    try:
        subprocess.check_call(["aws", "sts", "get-caller-identity"])
    except (OSError, subprocess.CalledProcessError):
        # OSError: the aws executable could not be started.
        # CalledProcessError: aws ran but reported a failure.
        # Anything else (e.g. Ctrl-C) is deliberately left to propagate.
        print(
            "update-caffe2-models.py:  please run `aws configure` manually to set up credentials"
        )
        sys.exit(1)
    commands[sys.argv[1]]()
